{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ae75a385-88ca-4cdf-b430-b9928adffbd3",
   "metadata": {},
   "source": [
    "# Example 00 - The Official Sig53 Dataset\n",
    "This notebook walks through an example of how the official Sig53 dataset can be instantiated and analyzed.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f634e735-56bc-459f-a04f-5c4eadc5f8dd",
   "metadata": {},
   "source": [
    "## Import Libraries\n",
    "First, import all the necessary public libraries as well as a few classes from the `torchsig` toolkit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec50ccc1-1b10-45be-8bb2-48c623e3d579",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.utils.visualize import IQVisualizer, SpectrogramVisualizer\n",
    "from torchsig.datasets.sig53 import Sig53\n",
    "from torchsig.utils.dataset import SignalDataset\n",
    "from torchsig.datasets import conf\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.nn import Identity\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "from torchsig.datasets.datamodules import Sig53DataModule\n",
    "from torchsig.transforms.target_transforms import DescToClassIndexSNR"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a901decc-9a77-4070-9d8d-ceb43af5d4af",
   "metadata": {},
   "source": [
    "### Instantiate Sig53 Dataset\n",
    "To instantiate the Sig53 dataset, several parameters are given to the imported `Sig53DataModule` class. These paramters are:\n",
    "- `root` ~ A string to specify the root directory of where to instantiate and/or read an existing Sig53 dataset\n",
    "- `impaired` ~ A boolean to specify if the Sig53 dataset should be the clean version or the impaired version\n",
    "- `qa` - A boolean to specify whether to generate a small subset of Sig53 (True), or the full dataset (False), default is True\n",
    "- `eb_no` - A boolean specifying if the SNR should be defined as Eb/No if True (making higher order modulations more powerful) or as Es/No if False (Defualt: False)\n",
    "- `transform` ~ Optionally, pass in any data transforms here if the dataset will be used in an ML training pipeline\n",
    "- `target_transform` ~ Optionally, pass in any target transforms here if the dataset will be used in an ML training pipeline\n",
    "\n",
    "A combination of the `impaired` and the `qa` booleans determines which of the four (4) distinct Sig53 datasets will be instantiated:\n",
    "| `impaired` | `qa` | Result |\n",
    "| ---------- | ---- | ------- |\n",
    "| `False` | `False` | Clean datasets of train=1.06M examples and val=5.3M examples |\n",
    "| `False` | `True` | Clean datasets of train=10600 examples and val=1060 examples |\n",
    "| `True` | `False` | Impaired datasets of train=1.06M examples and val=5.3M examples |\n",
    "| `True` | `True` | Impaired datasets of train=10600 examples and val=1060 examples |\n",
    "\n",
    "The final option of the impaired validation set is the dataset to be used when reporting any results with the official Sig53 dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "627e5be9-dd69-4df3-ab77-9bf1e35e6390",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate Sig53 DataModule\n",
    "root = \"./datasets/wideband_sig53\"\n",
    "class_list = list(Sig53._idx_to_name_dict.values())\n",
    "num_workers = 4\n",
    "impaired = False\n",
    "\n",
    "datamodule = Sig53DataModule(\n",
    "    root=root,\n",
    "    impaired=impaired,\n",
    "    transform=Identity(),\n",
    "    target_transform=DescToClassIndexSNR(class_list),\n",
    "    num_workers=num_workers\n",
    ")\n",
    "datamodule.prepare_data()\n",
    "datamodule.setup(\"fit\")\n",
    "sig53 = datamodule.train\n",
    "\n",
    "# Retrieve a sample and print out information\n",
    "idx = np.random.randint(len(sig53))\n",
    "data, (label, snr) = sig53[idx]\n",
    "print(\"Dataset length: {}\".format(len(sig53)))\n",
    "print(\"Data shape: {}\".format(data.shape))\n",
    "print(\"Label Index: {}\".format(label))\n",
    "print(\"Label Class: {}\".format(Sig53.convert_idx_to_name(label)))\n",
    "print(\"SNR: {}\".format(snr))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0815da97-fb34-4dcd-9c3d-560d573e1f27",
   "metadata": {},
   "source": [
    "## Plot Subset to Verify\n",
    "The `IQVisualizer` and the `SpectrogramVisualizer` can be passed a `Dataloader` and plot visualizations of the dataset. The `batch_size` of the `DataLoader` determines how many examples to plot for each iteration over the visualizer. Note that the dataset itself can be indexed and plotted sequentially using any familiar python plotting tools as an alternative plotting method to using the `torchsig` `Visualizer` as shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bf58bd2-b157-4f33-be6a-af1dd7f5ba26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For plotting, omit the SNR values\n",
    "class DataWrapper(SignalDataset):\n",
    "    def __init__(self, dataset):\n",
    "        self.dataset = dataset\n",
    "        super().__init__(dataset)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        x, (y, z) = self.dataset[idx]\n",
    "        return x, y\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return len(self.dataset)\n",
    "\n",
    "\n",
    "plot_dataset = DataWrapper(sig53)\n",
    "\n",
    "data_loader = DataLoader(dataset=plot_dataset, batch_size=16, shuffle=True)\n",
    "\n",
    "# Transform the plotting titles from the class index to the name\n",
    "def target_idx_to_name(tensor: np.ndarray) -> list:\n",
    "    batch_size = tensor.shape[0]\n",
    "    label = []\n",
    "    for idx in range(batch_size):\n",
    "        label.append(Sig53.convert_idx_to_name(int(tensor[idx])))\n",
    "    return label\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39311b83-3900-4e47-9d90-0f78353f07ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "visualizer = IQVisualizer(\n",
    "    data_loader=data_loader,\n",
    "    visualize_transform=None,\n",
    "    visualize_target_transform=target_idx_to_name,\n",
    ")\n",
    "\n",
    "for figure in iter(visualizer):\n",
    "    figure.set_size_inches(14, 9)\n",
    "    # plt.savefig(f\"{figure_dir}/00_iq_data.png\")\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad8b620f-3b50-4bb8-bf75-821ec1af37fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat but plot the spectrograms for a new random sampling of the data\n",
    "visualizer = SpectrogramVisualizer(\n",
    "    data_loader=data_loader,\n",
    "    nfft=1024,\n",
    "    visualize_transform=None,\n",
    "    visualize_target_transform=target_idx_to_name,\n",
    ")\n",
    "\n",
    "for figure in iter(visualizer):\n",
    "    figure.set_size_inches(14, 9)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49cadba0-d6f8-4dbe-9702-60c67aa1a855",
   "metadata": {},
   "source": [
    "## Analyze Dataset\n",
    "The dataset can also be analyzed at the macro level for details such as the distribution of classes and SNR values. This exercise is performed below to show the nearly uniform distribution across each."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcc62b77-811e-4a2c-a20a-ded5d61080f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loop through the dataset recording classes and SNRs\n",
    "class_counter_dict = {\n",
    "    class_name: 0 for class_name in list(Sig53._idx_to_name_dict.values())\n",
    "}\n",
    "all_snrs = []\n",
    "\n",
    "for idx in tqdm(range(len(sig53))):\n",
    "    data, (modulation, snr) = sig53[idx]\n",
    "    class_counter_dict[Sig53.convert_idx_to_name(modulation)] += 1\n",
    "    all_snrs.append(snr)\n",
    "\n",
    "\n",
    "# Plot the distribution of classes\n",
    "class_names = list(class_counter_dict.keys())\n",
    "num_classes = list(class_counter_dict.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3c9baa8-d7e5-422d-9703-9e701adfd028",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(9, 9))\n",
    "plt.pie(num_classes, labels=class_names)\n",
    "plt.title(\"Class Distribution Pie Chart\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a0fdaf4-8c48-48ca-a8bd-42cd8da3b5ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(11, 4))\n",
    "plt.bar(class_names, num_classes)\n",
    "plt.xticks(rotation=90)\n",
    "plt.title(\"Class Distribution Bar Chart\")\n",
    "plt.xlabel(\"Modulation Class Name\")\n",
    "plt.ylabel(\"Counts\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "501a871b-8b65-4fa4-a51e-e16853270f08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of SNR values\n",
    "plt.figure(figsize=(11, 4))\n",
    "plt.hist(x=all_snrs, bins=100)\n",
    "plt.title(\"SNR Distribution\")\n",
    "plt.xlabel(\"SNR Bins (dB)\")\n",
    "plt.ylabel(\"Counts\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
